Week 2: Linear models and causal inference

Categories and curves

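Workspace setup: the code in these notes assumes the rethinking, tidyverse, and patchwork packages are loaded. psych is called as psych::describe(), so it only needs to be installed.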

As we develop more useful models, we’ll practice building models with multiple estimands. An estimand is a quantity we want to estimate from the data. A model's parameters do not always answer our central question directly, so we need to know how to calculate these quantities from the posterior distribution.

Categories

Forget dummy codes. From here on out, we will incorporate categorical causes into our models by using index variables. An index variable contains integers that correspond to different categories. The numbers have no inherent meaning – rather, they stand as placeholders or shorthand for categories.

data("Howell1")
d <- Howell1
d$sex <- ifelse(d$male == 1, 2, 1) # 1 = female, 2 = male
head(d[, c("male", "sex")])
  male sex
1    1   2
2    0   1
3    0   1
4    1   2
5    0   1
6    1   2
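With two categories, ifelse() is enough. With more categories, one option is to convert the variable to a factor and take as.integer(), in which case the integers follow the factor's level order (alphabetical by default). Another is match() against an explicit ordering. A minimal base-R sketch; the vector `group` is made up for illustration:

```r
# Hypothetical category labels (not from Howell1)
group <- c("b", "c", "a", "c", "b")

# Option 1: factor levels are sorted alphabetically, so a = 1, b = 2, c = 3
idx_factor <- as.integer(factor(group))

# Option 2: match() against an explicit ordering gives full control
level_order <- c("c", "b", "a")   # c = 1, b = 2, a = 3
idx_match <- match(group, level_order)

idx_factor  # 2 3 1 3 2
idx_match   # 2 1 3 1 2
```

Either way, keep a record of which integer stands for which category, because the numbers themselves carry no meaning.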

Mathematical model

Let’s write a mathematical model to express height in terms of sex.

\[\begin{align*} h_i &\sim \text{Normal}(\mu_i, \sigma) \\ \mu_i &= \alpha_{SEX[i]} \\ \alpha_j &\sim \text{Normal}(178, 20)\text{ for }j = 1..2 \\ \sigma &\sim \text{Uniform}(0, 50) \end{align*}\]
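Before fitting, we can check what these priors imply by simulating heights from them (a prior predictive simulation). A minimal base-R sketch; the seed and the number of draws `n` are arbitrary choices:

```r
set.seed(1)
n <- 1e4
# Draw parameters from the priors: one alpha per sex, a shared sigma
alpha <- matrix(rnorm(2 * n, mean = 178, sd = 20), ncol = 2)
sigma <- runif(n, min = 0, max = 50)
# Simulate one height per draw for each sex index j = 1, 2
h_sim <- sapply(1:2, function(j) rnorm(n, mean = alpha[, j], sd = sigma))
# Central 89% of simulated heights for each sex
apply(h_sim, 2, quantile, probs = c(0.055, 0.945))
```

If a noticeable share of the simulated heights is implausible, that is a sign the priors deserve another look.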

flist <- alist(
  height ~ dnorm( mu , sigma) ,
  mu <- a[sex] ,
  a[sex] ~ dnorm( 178, 20 ) ,
  sigma ~ dunif(0, 50)
)

Fitting the model using quap()

m1 <- quap(
  flist, data=d
)

precis(m1, depth=2)
           mean        sd      5.5%     94.5%
a[1]  134.90961 1.6068759 132.34151 137.47771
a[2]  142.57765 1.6974105 139.86486 145.29044
sigma  27.30899 0.8279682  25.98574  28.63224

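Here we get estimates of the parameters specified in our model: the average height of females (a[1]) and of males (a[2]). Note that Howell1 includes children as well as adults, so these are averages over all ages. But our question is whether these average heights differ. To answer it, we compute the contrast, the difference a[1] - a[2], separately for every posterior sample.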

post <- extract.samples( m1 )
str(post)
List of 2
 $ sigma: num [1:10000] 26.6 28.4 28.3 27 27 ...
 $ a    : num [1:10000, 1:2] 135 134 136 135 133 ...
 - attr(*, "source")= chr "quap posterior: 10000 samples from m1"
head(post$a)
         [,1]     [,2]
[1,] 134.9650 143.2816
[2,] 134.0555 141.4182
[3,] 136.4525 144.1724
[4,] 134.8229 141.5401
[5,] 132.8612 143.5772
[6,] 136.8333 146.8423
post$diff_fm <- post$a[,1] - post$a[,2]
precis(post, depth=2 )
             mean        sd      5.5%      94.5%      histogram
sigma    27.30203 0.8271165  25.98077  28.628084 ▁▁▁▁▃▅▇▇▃▂▁▁▁▁
a[1]    134.92183 1.6053494 132.30030 137.489126  ▁▁▁▁▂▅▇▇▅▂▁▁▁
a[2]    142.56643 1.6954866 139.85725 145.229541 ▁▁▁▂▃▇▇▇▃▂▁▁▁▁
diff_fm  -7.64460 2.3182790 -11.34290  -3.959857     ▁▁▁▂▇▇▃▁▁▁
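precis() summarises the contrast, but the samples can also be used directly, for example to get an 89% interval or the posterior probability that females are shorter on average. To keep the sketch self-contained, it uses stand-in normal draws whose means and standard deviations are copied (rounded) from the precis(m1) table above, and it assumes a[1] and a[2] are uncorrelated; with the real model, use post$diff_fm instead.

```r
set.seed(2)
n <- 1e4
# Stand-in posterior draws approximating m1's marginals (not real samples)
a_f <- rnorm(n, 134.9, 1.61)
a_m <- rnorm(n, 142.6, 1.70)
# Contrast computed draw by draw: female minus male
diff_fm <- a_f - a_m
quantile(diff_fm, probs = c(0.055, 0.945))  # 89% compatibility interval
mean(diff_fm < 0)                           # posterior probability that females are shorter on average
```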

Calculate the contrast

We can create two plots. One shows the posterior distributions of the average female and male heights; the other shows the posterior distribution of their difference.

# Posterior distributions of average female (a.1) and male (a.2) height
p1 <- post %>% as.data.frame() %>% 
  pivot_longer(starts_with("a")) %>% 
  mutate(sex = ifelse(name == "a.1", "female", "male")) %>% 
  ggplot(aes(x=value, color = sex)) +
  geom_density(linewidth = 2) +
  labs(x = "height (cm)") 

# Posterior distribution of the contrast (female minus male)
p2 <- post %>% as.data.frame() %>% 
  ggplot(aes(x=diff_fm)) +
  geom_density(linewidth = 2) +
  labs(x = "difference in height, F - M (cm)") 

# Place the two plots side by side (patchwork syntax)
( p1 | p2)

Expected values vs predicted values

Note that the distribution of the mean heights is not the same as the distribution of heights themselves. To describe individual heights, we need the posterior predictive distribution, which adds the observation noise σ to each posterior draw of the mean.
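A small sketch of why this matters: predictive draws are far more spread out than draws of the mean, and individual heights overlap much more than average heights do. The stand-in draws below approximate m1 using the precis(m1) values (rounded, treated as independent, which is an assumption); they are not real posterior samples.

```r
set.seed(3)
n <- 1e4
# Stand-in posterior draws approximating m1 (not real samples)
a_f   <- rnorm(n, 134.9, 1.61)
a_m   <- rnorm(n, 142.6, 1.70)
sigma <- rnorm(n, 27.3, 0.83)
# Posterior predictive: one simulated individual per posterior draw
pred_f <- rnorm(n, mean = a_f, sd = sigma)
pred_m <- rnorm(n, mean = a_m, sd = sigma)
# Spread of the expected value vs. spread of individual heights
c(sd_mean = sd(a_f), sd_pred = sd(pred_f))
# Probability that a randomly drawn female is taller than a randomly drawn male
mean(pred_f - pred_m > 0)
```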

pred_f  <- rnorm(1e4, mean = post$a[,1], sd = post$sigma )
pred_m  <- rnorm(1e4, mean = post$a[,2], sd = post$sigma )

pred_post = data.frame(pred_f, pred_m) %>%
  mutate(diff = pred_f-pred_m)

# plot distributions
p1 <- pred_post %>% pivot_longer(starts_with("pred")) %>% 
  mutate(sex = ifelse(name == "pred_f", "female", "male")) %>% 
  ggplot(aes(x = value, color = sex)) +
  geom_density(linewidth = 2) +
  labs(x = "height (cm)")

# plot difference
# Compute density first
density_data <- density(pred_post$diff)

# Convert to a tibble for plotting
density_df <- tibble(
  x = density_data$x,
  y = density_data$y,
  fill_group = ifelse(x < 0, "male", "female")  # negative difference (F - M): male taller
)

# Plot with area fill, shaded by the sign of the difference
p2 <- ggplot(density_df, aes(x = x, y = y, fill = fill_group)) +
  geom_area() +  # fill under the density curve
  geom_line(aes(group = 1), linewidth = 1.2, color = "black") +  # group = 1 keeps one continuous curve
  labs(x = "Difference in height (F-M)", y = "density") +
  guides(fill = "none")

(p1 | p2)

exercise

In the rethinking package, the dataset milk contains information about the composition of milk across primate species, as well as some other facts about those species. The taxonomic membership of each species is included in the variable clade; there are four categories.

  1. Create a variable in the dataset that assigns an index value to each of the 4 categories.
  2. Standardize the milk energy variable (kcal.per.g).
  3. Write a mathematical model to express the average milk energy (in standardized kilocalories) in each clade.
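Standardizing means subtracting the mean and dividing by the standard deviation, so the new variable has mean 0 and standard deviation 1. rethinking's standardize() should give the same values as base R's scale(). A base-R sketch with a handful of made-up values:

```r
# Hypothetical raw values (stand-ins for kcal.per.g)
x <- c(0.49, 0.51, 0.46, 0.48, 0.60)
# z-score by hand: (x - mean) / sd
z <- (x - mean(x)) / sd(x)
# Base R's scale() computes the same thing (it returns a one-column matrix)
z_scale <- as.numeric(scale(x))
round(c(mean = mean(z), sd = sd(z)), 10)
```

On the standardized scale, a value of 1 means "one standard deviation above the average milk energy", which is what makes priors like Normal(0, 0.5) easy to reason about.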

solution

data("milk")
str(milk)
'data.frame':   29 obs. of  8 variables:
 $ clade         : Factor w/ 4 levels "Ape","New World Monkey",..: 4 4 4 4 4 2 2 2 2 2 ...
 $ species       : Factor w/ 29 levels "A palliata","Alouatta seniculus",..: 11 8 9 10 16 2 1 6 28 27 ...
 $ kcal.per.g    : num  0.49 0.51 0.46 0.48 0.6 0.47 0.56 0.89 0.91 0.92 ...
 $ perc.fat      : num  16.6 19.3 14.1 14.9 27.3 ...
 $ perc.protein  : num  15.4 16.9 16.9 13.2 19.5 ...
 $ perc.lactose  : num  68 63.8 69 71.9 53.2 ...
 $ mass          : num  1.95 2.09 2.51 1.62 2.19 5.25 5.37 2.51 0.71 0.68 ...
 $ neocortex.perc: num  55.2 NA NA NA NA ...
milk$clade_id <- as.integer(milk$clade)
milk$K <- standardize(milk$kcal.per.g)

\[\begin{align*} K_i &\sim \text{Normal}(\mu_i, \sigma) \\ \mu_i &= \alpha_{\text{CLADE}[i]} \\ \alpha_j &\sim \text{Normal}(0, 0.5) \text{ for }j=1..4 \\ \sigma &\sim \text{Exponential}(1) \end{align*}\]

Exercise: Now fit your model using quap(). It’s ok if your mathematical model is a bit different from mine.

solution

flist <- alist(
  K ~ dnorm( mu , sigma ) ,
  mu <- a[clade_id] , 
  a[clade_id] ~ dnorm( 0 , 0.5 ) , 
  sigma ~ dexp( 1 )
)

m2 <- quap(
  flist, data = milk
)

precis( m2, depth=2 )
            mean         sd        5.5%      94.5%
a[1]  -0.4843500 0.21764119 -0.83218268 -0.1365174
a[2]   0.3662535 0.21705892  0.01935142  0.7131556
a[3]   0.6752199 0.25753413  0.26363060  1.0868092
a[4]  -0.5858078 0.27450906 -1.02452627 -0.1470893
sigma  0.7196454 0.09653344  0.56536635  0.8739245

Plotting with rethinking

labels <- paste( "a[" , 1:4, "]:", levels(milk$clade),  sep="" )
plot(
  precis(m2, depth=2, pars = "a"),
  labels=labels, 
  xlab="expected kcal (std)"
)

exercise

Plot the following distributions:

  • Posterior distribution of average milk energy by clade.
  • Posterior distribution of predicted milk energy values by clade.

solution

post <- extract.samples( m2 )
names(labels) = paste("a.", 1:4, sep = "")
post %>% as.data.frame() %>% 
  pivot_longer(starts_with("a")) %>% 
  mutate(name = recode(name, !!!labels)) %>% 
  ggplot(aes(x = value, color = name)) +
  geom_density(linewidth = 2) +
  labs(title = "Posterior distribution of expected milk energy")

solution

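post <- extract.samples( m2 )
# Predicted (standardized) milk energy: add observation noise sigma to each
# posterior draw of the clade mean. The columns are named a.1..a.4 only so
# that recode() can reuse the `labels` vector defined above.
a.1 = rnorm(1e4, post$a[,1], post$sigma)
a.2 = rnorm(1e4, post$a[,2], post$sigma)
a.3 = rnorm(1e4, post$a[,3], post$sigma)
a.4 = rnorm(1e4, post$a[,4], post$sigma)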
data.frame(a.1, a.2, a.3, a.4) %>% 
  pivot_longer(everything()) %>% 
  mutate(name = recode(name, !!!labels)) %>% 
  ggplot(aes(x = value, color = name)) +
  geom_density(linewidth = 2) +
  labs(title = "Posterior distribution of predicted milk energy")
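With four clades there are six pairwise contrasts, each computed draw by draw exactly as in the height example. A self-contained sketch, using stand-in draws whose means and standard deviations are copied (rounded) from precis(m2) and treating the clades as independent (an assumption); with the real model, replace `a` by post$a:

```r
set.seed(4)
n <- 1e4
means <- c(-0.48, 0.37, 0.68, -0.59)
sds   <- c(0.22, 0.22, 0.26, 0.27)
# Stand-in draws: one column per clade index (not real posterior samples)
a <- sapply(1:4, function(j) rnorm(n, means[j], sds[j]))
# All pairs of clade indices (6 pairs for 4 clades)
pairs <- combn(4, 2)
contrasts <- apply(pairs, 2, function(p) a[, p[1]] - a[, p[2]])
colnames(contrasts) <- apply(pairs, 2, paste, collapse = "-")
# Posterior mean and 89% interval of each contrast
t(apply(contrasts, 2, function(d) c(mean = mean(d), quantile(d, c(0.055, 0.945)))))
```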

Combining index variables and slopes

Let’s return to the height example. What if we want to control for weight?

\[\begin{align*} h_i &\sim \text{Normal}(\mu_i, \sigma) \\ \mu_i &= \alpha_{S[i]} + \beta_{S[i]}(W_i-\bar{W})\\ \alpha_j &\sim \text{Normal}(178, 20)\text{ for }j = 1..2 \\ \beta_j &\sim \text{Normal}(0, 20)\text{ for }j = 1..2 \\ \sigma &\sim \text{Uniform}(0, 50) \end{align*}\]
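Centering weight at its mean (`weight - Wbar`) is what lets each α be read as the expected height for that sex at the overall average weight, rather than at a weight of zero. A single-group illustration, using ordinary least squares via lm() on simulated data as a stand-in for quap() (the data and coefficients are made up): with a centered predictor, the fitted intercept equals the mean of the outcome, while the slope is unchanged.

```r
set.seed(5)
# Simulated stand-in data, not Howell1
w <- rnorm(200, mean = 35, sd = 14)
h <- 140 + 1.8 * (w - mean(w)) + rnorm(200, sd = 9)
# Centered vs. uncentered fits
fit_c <- lm(h ~ I(w - mean(w)))
fit_u <- lm(h ~ w)
# Centered intercept = mean height; uncentered intercept = height at w = 0
c(centered = unname(coef(fit_c)[1]), mean_h = mean(h), uncentered = unname(coef(fit_u)[1]))
```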

dat <- list(
  height = d$height,
  weight = d$weight,
  Wbar   = mean(d$weight),  # use = (not <-) so the list element is named Wbar
  sex    = d$male + 1       # 1 = female, 2 = male
)

flist <- alist(
  height ~ dnorm( mu , sigma) ,
  mu <- a[sex] + b[sex]*(weight-Wbar),
  a[sex] ~ dnorm( 178, 20 ) ,
  b[sex] ~ dnorm( 0, 20 ) ,
  sigma ~ dunif(0, 50)
)

m3 <- quap(flist, data=dat)
precis(m3, depth=3)
            mean         sd       5.5%      94.5%
a[1]  138.565626 0.55777101 137.674200 139.457052
a[2]  138.162933 0.58854642 137.222322 139.103543
b[1]    1.805475 0.04138460   1.739334   1.871615
b[2]    1.735288 0.03688952   1.676331   1.794244
sigma   9.330414 0.28287280   8.878329   9.782499
post <- extract.samples(m3)
str(post)
List of 3
 $ sigma: num [1:10000] 9.33 9.22 9.63 9.33 9.4 ...
 $ a    : num [1:10000, 1:2] 138 138 138 138 138 ...
 $ b    : num [1:10000, 1:2] 1.81 1.75 1.73 1.76 1.85 ...
 - attr(*, "source")= chr "quap posterior: 10000 samples from m3"

Plot the slopes using extract.samples()

xbar <- mean(d$weight)
plot(NULL, xlim = range(d$weight), ylim = c(0, 200),
     xlab = "weight (kg)", ylab = "height (cm)")
# plot the lines implied by the first 50 posterior samples for each sex
for(i in 1:50){
  curve(post$a[i, 1] + post$b[i, 1]*(x - xbar),   # females
        add = TRUE,
        col = col.alpha("#1c5253", 0.1))
  curve(post$a[i, 2] + post$b[i, 2]*(x - xbar),   # males
        add = TRUE,
        col = col.alpha("#e07a5f", 0.1))
}

Plot the slopes using link()

xseq <- seq( min(d$weight), max(d$weight), len=100)
plot(NULL, xlim = range(d$weight), ylim = range(d$height), xlab = "weight", ylab = "height")
muF <- 
  link(m3, data=list(sex=rep(1,100), weight=xseq, Wbar = mean(d$weight)))
lines(xseq, apply(muF, 2, mean), lwd = 2, col = "#1c5253" )
muM <- 
  link(m3, data=list(sex=rep(2,100), weight=xseq, Wbar = mean(d$weight)))
lines(xseq, apply(muM, 2, mean), lwd = 2, col =  "#e07a5f")
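Conceptually, link() evaluates the linear model for every posterior sample at every new weight and returns a matrix with one row per sample and one column per weight, which is why the code above averages with apply(muF, 2, mean). A base-R sketch of that computation, assuming hypothetical stand-in draws `a_draws` and `b_draws` for one sex (not output of extract.samples()) and a made-up weight grid:

```r
set.seed(6)
n_samples <- 1000
a_draws <- rnorm(n_samples, 138.6, 0.56)   # stand-in intercept draws
b_draws <- rnorm(n_samples, 1.81, 0.04)    # stand-in slope draws
Wbar <- 35                                 # hypothetical mean weight
xseq <- seq(5, 60, length.out = 100)       # hypothetical weight grid
# mu[i, k] = a_i + b_i * (xseq_k - Wbar): one row per sample, one column per weight
mu <- a_draws + outer(b_draws, xseq - Wbar)
mu_mean <- colMeans(mu)                                  # same as apply(mu, 2, mean)
mu_PI <- apply(mu, 2, quantile, probs = c(0.055, 0.945)) # 89% interval at each weight
```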

exercise

Return to the milk data. Write a mathematical model expressing the energy of milk as a function of the species' body mass (mass) and clade. Be sure to include priors. Fit your model using quap().

solution

\[\begin{align*} K_i &\sim \text{Normal}(\mu_i, \sigma) \\ \mu_i &= \alpha_{\text{CLADE}[i]} + \beta_{\text{CLADE}[i]}(M_i-\bar{M})\\ \alpha_j &\sim \text{Normal}(0, 0.5) \text{ for }j=1..4 \\ \beta_j &\sim \text{Normal}(0, 0.5) \text{ for }j=1..4 \\ \sigma &\sim \text{Exponential}(1) \end{align*}\]

dat <- list(
  K        = standardize(milk$kcal.per.g),
  M        = milk$mass,
  Mbar     = mean(milk$mass),
  clade_id = milk$clade_id
)

flist <- alist(
  K ~ dnorm( mu , sigma ) ,
  mu <- a[clade_id] +b[clade_id]*(M-Mbar), 
  a[clade_id] ~ dnorm( 0 , 0.5 ) , 
  b[clade_id] ~ dnorm( 0 , 0.5 ) , 
  sigma ~ dexp( 1 )
)

m4 <- quap(
  flist, data = dat
)
precis( m4, depth=2 )
              mean          sd         5.5%        94.5%
a[1]  -0.434263963 0.261118910 -0.851582415 -0.016945512
a[2]  -0.282793829 0.478554245 -1.047615940  0.482028282
a[3]   0.368822352 0.418885791 -0.300638045  1.038282749
a[4]  -0.005061035 0.498109188 -0.801135722  0.791013652
b[1]  -0.002670551 0.007183316 -0.014150878  0.008809775
b[2]  -0.061303611 0.040137284 -0.125450743  0.002843521
b[3]  -0.047926482 0.050277719 -0.128279987  0.032427024
b[4]   0.064917378 0.046232330 -0.008970815  0.138805572
sigma  0.692511084 0.092024929  0.545437474  0.839584694
xseq <- seq( min(milk$mass), max(milk$mass), len=100)
Mbar = mean(milk$mass)
custom_colors = c("#1c5253", "#e07a5f", "#f2cc8f", "#81b29a")
colors = custom_colors[milk$clade_id]
plot(milk$K ~ milk$mass, col = colors, 
     pch = 16,
     xlim = range(milk$mass), ylim = range(milk$K), 
     xlab = "body mass (kg)", ylab = "kcal per g (standardized)")
mu1 <- 
  link(m4, data=list(clade_id=rep(1,100), M=xseq, Mbar = Mbar))
lines(xseq, apply(mu1, 2, mean), lwd = 2, col = "#1c5253" )
mu2 <- 
  link(m4, data=list(clade_id=rep(2,100), M=xseq, Mbar = Mbar))
lines(xseq, apply(mu2, 2, mean), lwd = 2, col = "#e07a5f" )
mu3 <- 
  link(m4, data=list(clade_id=rep(3,100), M=xseq, Mbar = Mbar))
lines(xseq, apply(mu3, 2, mean), lwd = 2, col = "#f2cc8f" )
mu4 <- 
  link(m4, data=list(clade_id=rep(4,100), M=xseq, Mbar = Mbar))
lines(xseq, apply(mu4, 2, mean), lwd = 2, col = "#81b29a" )
legend("topright", legend = levels(milk$clade), 
       col = custom_colors, pch = 16)

Splines

data(cherry_blossoms)
d <- cherry_blossoms
precis(d)
                  mean          sd      5.5%      94.5%       histogram
year       1408.000000 350.8845964 867.77000 1948.23000   ▇▇▇▇▇▇▇▇▇▇▇▇▁
doy         104.540508   6.4070362  94.43000  115.00000        ▁▂▅▇▇▃▁▁
temp          6.141886   0.6636479   5.15000    7.29470        ▁▃▅▇▃▂▁▁
temp_upper    7.185151   0.9929206   5.89765    8.90235 ▁▂▅▇▇▅▂▂▁▁▁▁▁▁▁
temp_lower    5.098941   0.8503496   3.78765    6.37000 ▁▁▁▁▁▁▁▃▅▇▃▂▁▁▁
psych::describe(d)
           vars    n    mean     sd  median trimmed    mad    min     max
year          1 1215 1408.00 350.88 1408.00 1408.00 450.71 801.00 2015.00
doy           2  827  104.54   6.41  105.00  104.54   5.93  86.00  124.00
temp          3 1124    6.14   0.66    6.10    6.11   0.61   4.67    8.30
temp_upper    4 1124    7.19   0.99    7.04    7.10   0.92   5.45   12.10
temp_lower    5 1124    5.10   0.85    5.14    5.10   0.72   0.75    7.74
             range  skew kurtosis    se
year       1214.00  0.00    -1.20 10.07
doy          38.00  0.00    -0.15  0.22
temp          3.63  0.40     0.11  0.02
temp_upper    6.65  1.05     1.71  0.03
temp_lower    6.99 -0.17     1.88  0.03
d2 <- d[ complete.cases(d$doy) , ] # complete cases on doy
num_knots <- 15
knot_list <- quantile( d2$year , probs=seq(0,1,length.out=num_knots) )
knot_list
       0% 7.142857% 14.28571% 21.42857% 28.57143% 35.71429% 42.85714%       50% 
      812      1036      1174      1269      1377      1454      1518      1583 
57.14286% 64.28571% 71.42857% 78.57143% 85.71429% 92.85714%      100% 
     1650      1714      1774      1833      1893      1956      2015 
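Placing knots at quantiles rather than at equally spaced years puts more knots where observations are dense, so each segment between knots holds roughly the same number of points. A base-R sketch on simulated, deliberately skewed data (not the cherry blossom years):

```r
set.seed(7)
x <- rexp(1000)                      # skewed stand-in data, dense near 0
num_knots <- 15
knots <- quantile(x, probs = seq(0, 1, length.out = num_knots))
# Count observations falling in each of the 14 segments between knots
counts <- table(findInterval(x, knots, rightmost.closed = TRUE))
counts
diff(knots)                          # segments are narrow where data are dense
```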
d2 %>% 
  ggplot(aes(x = year, y = doy)) +
  geom_point(color = "#ffb7c5", alpha = 1/2) + 
  geom_vline(xintercept = knot_list, color = "white", alpha = .5) +
  theme(panel.grid = element_blank(),
        panel.background = element_rect(fill = "#4f455c"))
library(splines)
B <- bs(d2$year,
  knots=knot_list[-c(1,num_knots)] ,
  degree=3 , intercept=TRUE )

plot( NULL , xlim=range(d2$year) , ylim=c(0,1) , xlab="year" , ylab="basis" )
for ( i in 1:ncol(B) ) lines( d2$year , B[,i] )
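Two properties of this basis are worth checking: with intercept = TRUE the basis functions sum to 1 at every year, and because each cubic basis function is nonzero only over a few neighbouring segments, at most degree + 1 = 4 of them are active at any point. A sketch with the splines package on a made-up sequence of years standing in for d2$year:

```r
library(splines)
years <- seq(800, 2015, by = 5)                  # hypothetical stand-in for d2$year
num_knots <- 15
knot_list <- quantile(years, probs = seq(0, 1, length.out = num_knots))
B <- bs(years, knots = knot_list[-c(1, num_knots)], degree = 3, intercept = TRUE)
dim(B)                      # 13 interior knots + degree 3 + intercept = 17 columns
range(rowSums(B))           # every row sums to 1
max(rowSums(B > 1e-12))     # at most 4 nonzero basis functions per year
```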

Mathematical model

\[\begin{align*} D_i &\sim \text{Normal}(\mu_i,\sigma) \\ \mu_i &= \alpha + \sum^K_{k=1} w_k B_{k,i} \\ \alpha &\sim \text{Normal}(100,10) \\ w_k &\sim \text{Normal}(0,10) \\ \sigma &\sim \text{Exponential}(1) \end{align*}\]
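The sum over k is just a matrix product, which is why the code writes B %*% w. A base-R sketch that draws α and w once from the priors above and checks that the matrix product matches the explicit sum; the year grid is a made-up stand-in for d2$year:

```r
library(splines)
set.seed(8)
years <- seq(800, 2015, by = 5)                  # hypothetical stand-in for d2$year
knot_list <- quantile(years, probs = seq(0, 1, length.out = 15))
B <- bs(years, knots = knot_list[-c(1, 15)], degree = 3, intercept = TRUE)
Bm <- matrix(B, nrow = nrow(B))                  # plain numeric matrix copy of the basis
# One draw from the priors: alpha ~ Normal(100, 10), w_k ~ Normal(0, 10)
a <- rnorm(1, 100, 10)
w <- rnorm(ncol(Bm), 0, 10)
# Linear predictor two ways: matrix product vs. explicit sum over k
mu_matrix <- as.vector(a + Bm %*% w)
mu_sum <- a + sapply(seq_along(years), function(i) sum(w * Bm[i, ]))
range(mu_matrix)   # prior day-of-year curve implied by this single draw
```

Repeating the draw many times and plotting mu_matrix against years would give a prior predictive picture of the wiggly curves these priors allow.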

m4.7 <- quap(
  alist(
    D ~ dnorm( mu , sigma ) ,
    mu <- a + B %*% w ,
    a ~ dnorm(100,10),
    w ~ dnorm(0,10),
    sigma ~ dexp(1)
  ), 
  data=list( D=d2$doy , B=B ) ,
  # start values tell quap() that w is a vector with one weight per basis function
  start=list( w=rep( 0 , ncol(B) ) )
)